We further show that the oscillation is in fact controlled by the balanced parameter attached to the reconstruction loss, which provides a theoretical foundation for parameterizing it in backpropagation. The oscillation occurs only when the gradient has a magnitude large enough to change the sign of the latent weight. Consequently, we calculate the balanced parameter from the maximum magnitude of the weight gradient in each iteration, leading to resilient gradients and effectively mitigating weight oscillation.
3.9.1 Problem Formulation
Most existing implementations simply follow previous studies [199, 159] and optimize $A$ and the latent weights $W$ through a nonparametric bilevel optimization:
$$W^{*} = \arg\min_{W} \; \mathcal{L}(W; A^{*}), \quad (3.139)$$
$$\text{s.t.} \quad \alpha^{n*} = \arg\min_{\alpha^{n}} \; \|\mathbf{w}^{n} - \alpha^{n} \circ \mathbf{b}_{\mathbf{w}^{n}}\|_{2}^{2}, \quad (3.140)$$
where $\mathcal{L}(\cdot)$ represents the training loss. Consequently, a closed-form solution of $\alpha^{n}$ can be derived by the channelwise absolute mean (CAM) as $\alpha^{n}_{i} = \frac{\|\mathbf{w}^{n}_{i,:,:,:}\|_{1}}{M^{n}}$, with $M^{n} = C^{n}_{in} \times K^{n} \times K^{n}$.
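For concreteness, the CAM solution can be sketched in a few lines of PyTorch-style code. This is a minimal sketch, assuming a $(C_{out}, C_{in}, K, K)$ weight layout; the function name is illustrative and not taken from the source.

```python
import torch

def binarize_cam(w):
    """Binarize a 4-D latent weight tensor (C_out, C_in, K, K) and compute the
    channelwise absolute mean (CAM) scaling factor of Eq. (3.140).

    Returns (alpha, b_w, w_hat), where w_hat = alpha * sign(w) approximates w.
    """
    c_out = w.shape[0]
    m = w[0].numel()                                   # M^n = C_in * K * K
    # alpha_i = ||w_i||_1 / M^n  (closed-form CAM solution)
    alpha = w.abs().reshape(c_out, -1).sum(dim=1) / m  # shape (C_out,)
    # binary weights in {-1, +1}
    b_w = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w))
    w_hat = alpha.view(-1, 1, 1, 1) * b_w              # reconstructed weights
    return alpha, b_w, w_hat
```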
For ease of representation, we use $\mathbf{w}^{n}_{i}$ as an alternative to $\mathbf{w}^{n}_{i,:,:,:}$ in the following. The latent weight $\mathbf{w}^{n}$ is updated using a standard gradient backpropagation algorithm, and its gradient is calculated as:
$$\delta_{\mathbf{w}^{n}_{i}} = \frac{\partial \mathcal{L}}{\partial \mathbf{w}^{n}_{i}} = \frac{\partial \mathcal{L}}{\partial \hat{\mathbf{w}}^{n}_{i}} \, \frac{\partial \hat{\mathbf{w}}^{n}_{i}}{\partial \mathbf{w}^{n}_{i}} = \alpha^{n}_{i} \, \frac{\partial \mathcal{L}}{\partial \hat{\mathbf{w}}^{n}_{i}} \circledast \mathbf{1}_{|\mathbf{w}^{n}_{i}| \le 1}, \quad (3.141)$$
where $\circledast$ denotes the Hadamard product and $\hat{\mathbf{w}}^{n} = \alpha^{n} \circ \mathbf{b}_{\mathbf{w}^{n}}$.
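A minimal sketch of this forward/backward rule as a custom autograd function is given below, assuming a PyTorch-style interface; the class name is illustrative, and $\alpha^{n}$ is treated as the fixed nonparametric CAM factor from Eq. (3.140).

```python
import torch

class BinarizeSTE(torch.autograd.Function):
    """Sketch of Eq. (3.141): the forward pass outputs w_hat = alpha * sign(w);
    the backward pass scales the incoming gradient by alpha and masks it where
    |w| > 1 (clipped straight-through estimator)."""

    @staticmethod
    def forward(ctx, w, alpha):
        ctx.save_for_backward(w, alpha)
        b_w = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w))
        return alpha.view(-1, 1, 1, 1) * b_w

    @staticmethod
    def backward(ctx, grad_w_hat):
        w, alpha = ctx.saved_tensors
        # delta_w_i = alpha_i * dL/dw_hat_i, masked by the indicator 1_{|w_i| <= 1}
        mask = (w.abs() <= 1).to(grad_w_hat.dtype)
        grad_w = alpha.view(-1, 1, 1, 1) * grad_w_hat * mask
        return grad_w, None  # alpha is nonparametric here, so it gets no gradient
```

Calling `BinarizeSTE.apply(w, alpha)` in place of `w` inside a convolution reproduces the clipped straight-through gradient of Eq. (3.141) during backpropagation.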
Discussion. Equation (3.141) shows that the weight gradient mainly comes from the nonparametric $\alpha^{n}_{i}$ and the gradient $\frac{\partial \mathcal{L}}{\partial \hat{\mathbf{w}}^{n}_{i}}$. The term $\frac{\partial \mathcal{L}}{\partial \hat{\mathbf{w}}^{n}_{i}}$ is obtained automatically in backpropagation and becomes smaller as the network converges; however, $\alpha^{n}_{i}$ is often magnified by the trimodal distribution [158]. Therefore, the weight oscillation originates mainly from $\alpha^{n}_{i}$. Given a single weight $w^{n}_{i,j}$ ($1 \le j \le M^{n}$) centered around zero, the gradient $\frac{\partial \mathcal{L}}{\partial w^{n}_{i,j}}$ is misleading due to the significant gap between $w^{n}_{i,j}$ and $\alpha^{n}_{i} b_{w^{n}_{i,j}}$. Consequently, bilevel optimization leads to frequent weight oscillations. To address this issue, we reformulate the traditional bilevel optimization using a Lagrange multiplier and show that a learnable scaling factor is a natural training stabilizer.
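The source gives no code for quantifying this oscillation, but one simple way to observe it during training is to count sign flips of the latent weights between consecutive iterations; the helper below is purely illustrative.

```python
import torch

def sign_flip_ratio(w_prev, w_curr):
    """Fraction of latent weights whose sign flipped between two consecutive
    iterations; frequent flips of near-zero weights correspond to the weight
    oscillation discussed above (illustrative helper, not from the source)."""
    flipped = (torch.sign(w_prev) * torch.sign(w_curr)) < 0
    return flipped.float().mean().item()
```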
3.9.2 Method
We first give the learning objective as follows:
$$\arg\min_{W, A} \; \mathcal{L}(W, A) + \mathcal{L}_{R}(W, A), \quad (3.142)$$
where $\mathcal{L}_{R}(W, A)$ is a weighted reconstruction loss defined as:
$$\mathcal{L}_{R}(W, A) = \frac{1}{2} \sum_{n=1}^{N} \sum_{i=1}^{C_{out}} \gamma^{n}_{i} \, \|\mathbf{w}^{n}_{i} - \alpha^{n}_{i} \mathbf{b}_{\mathbf{w}^{n}_{i}}\|_{2}^{2}, \quad (3.143)$$
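A minimal sketch of this weighted reconstruction loss is given below, assuming per-layer lists of latent weights, CAM factors, and balanced parameters. The rule for deriving $\gamma$ from the maximum weight-gradient magnitude is only stated informally above, so $\gamma$ is treated as an input here; all names are illustrative.

```python
import torch

def reconstruction_loss(weights, alphas, gammas):
    """Sketch of the weighted reconstruction loss L_R in Eq. (3.143).

    weights: list of 4-D latent weight tensors w^n, one per binarized layer
    alphas:  list of per-output-channel scaling factors alpha^n, shape (C_out,)
    gammas:  list of per-output-channel balanced parameters gamma^n, shape (C_out,)
    """
    loss = 0.0
    for w, alpha, gamma in zip(weights, alphas, gammas):
        b_w = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w))
        w_hat = alpha.view(-1, 1, 1, 1) * b_w
        # per-channel squared reconstruction error, weighted by gamma_i^n
        err = ((w - w_hat) ** 2).flatten(1).sum(dim=1)  # shape (C_out,)
        loss = loss + 0.5 * (gamma * err).sum()
    return loss
```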